package ck_redis

import (
	"errors"
	"fmt"
	"github.com/gomodule/redigo/redis"
	"time"
)

// Lock is a distributed lock backed by a single Redis key.
type Lock struct {
	resource string // name of the locked resource; the Redis key is "redislock:" + resource
	token    string // value stored in the key; Unlock only deletes the key while it still holds this token
	timeout  int    // lock TTL in seconds, passed to SET ... EX when acquiring
	domain   string // which Redis instance to use; passed to Redis(domain)
}

// tryLock makes a single, non-blocking attempt to acquire the lock by issuing
// SET key token EX timeout NX.
//
// When the key already exists, Redis replies nil and redigo reports that as
// redis.ErrNil. That case is reported as ok == false with a nil error. Any
// other failure is returned in err with ok == false.
func (lock *Lock) tryLock() (ok bool, err error) {
	_, err = redis.String(lock.do("SET", lock.key(), lock.token, "EX", lock.timeout, "NX"))
	switch {
	case err == nil:
		return true, nil
	case err == redis.ErrNil:
		// Someone else holds the key: not an error, just not acquired.
		return false, nil
	default:
		return false, err
	}
}

// Unlock releases the lock, but only while the key still holds this Lock's
// token.
//
// The ownership check and the delete run as one Lua script (EVAL), so they
// happen atomically on the Redis server. With a separate GET followed by DEL,
// the key could expire and be re-acquired by another client in between, and
// the DEL would then remove that other client's lock.
//
// It returns an error reading "unlock failed" when the key is missing or holds
// a different token. Errors from the Redis call itself are returned unchanged.
func (lock *Lock) Unlock() (err error) {
	const compareAndDelete = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0`
	// do() cannot be reused here: it maps its FIRST argument through associate,
	// and for EVAL that argument is the script. Apply associate to the key
	// explicitly so the key matches what do() would have sent.
	deleted, err := redis.Int(Redis(lock.domain).Do("EVAL", compareAndDelete, 1, lock.associate(lock.key()), lock.token))
	if err != nil {
		return err
	}
	if deleted == 0 {
		return errors.New("unlock failed")
	}
	return nil
}

// key builds the Redis key that stores this lock. All locks share the
// "redislock" namespace, followed by a colon and the resource name.
func (lock *Lock) key() string {
	const namespace = "redislock"
	return fmt.Sprintf("%s:%s", namespace, lock.resource)
}

// WaitForUnlock polls the lock key every 50ms until it is released, giving up
// after roughly timeout.
//
// It returns:
//   - (true, nil) as soon as the key no longer exists or holds an empty value.
//   - (false, err) if a Redis call fails.
//   - (false, an error reading "wait timeout") once the polling budget is used up.
//
// It only observes the key; it does not acquire the lock.
func (lock *Lock) WaitForUnlock(timeout time.Duration) (bool, error) {
	const delay = 50 * time.Millisecond
	// The key is checked once up front, then once more after each sleep. At
	// most maxTimes sleeps happen, so the total sleep time is at most timeout.
	maxTimes := int(timeout / delay)
	times := 0
	for {
		str, err := redis.String(lock.do("get", lock.key()))
		if err == redis.ErrNil {
			// GET on a missing key replies nil, which redigo reports as
			// redis.ErrNil. A missing key means the lock was released or expired.
			return true, nil
		}
		if err != nil {
			return false, err
		}
		if str == "" {
			return true, nil
		}
		times++
		if times > maxTimes {
			return false, errors.New("wait timeout")
		}
		time.Sleep(delay)
	}
}

// AddTimeout extends the lock's remaining TTL by exTime seconds, but only
// while the key still holds this Lock's token.
//
// It returns ok == true when the expiry was extended. It returns ok == false
// with a nil error when any of these hold:
//   - the key is missing or held by a different token;
//   - the key has no positive TTL;
//   - the resulting TTL would not be positive.
//
// Errors from the Redis call are returned in err.
//
// The ownership check, the TTL read and the EXPIRE run as one Lua script, so
// no other client can take the key in between. A plain SET here would
// overwrite another holder's lock.
func (lock *Lock) AddTimeout(exTime int64) (ok bool, err error) {
	const extendIfOwner = `
if redis.call("GET", KEYS[1]) ~= ARGV[1] then
	return 0
end
local ttl = redis.call("TTL", KEYS[1])
if ttl <= 0 then
	return 0
end
local newTTL = ttl + tonumber(ARGV[2])
if newTTL <= 0 then
	return 0
end
return redis.call("EXPIRE", KEYS[1], newTTL)`
	// do() maps its first argument through associate, but for EVAL that is the
	// script, so the key is mapped explicitly to match what do() would send.
	extended, err := redis.Int(Redis(lock.domain).Do("EVAL", extendIfOwner, 1, lock.associate(lock.key()), lock.token, exTime))
	if err != nil {
		return false, err
	}
	// EXPIRE replies 1 when the new timeout was applied.
	return extended == 1, nil
}

// TryLock makes one non-blocking attempt to acquire the lock on resource, with
// a TTL of defaultTimeout seconds, on the "default" Redis domain.
//
// It is shorthand for TryLockWithTimeout(resource, token, defaultTimeout)
// called without a domain, and it returns that call's results unchanged.
func TryLock(resource string, token string, defaultTimeout int) (lock *Lock, ok bool, err error) {
	lock, ok, err = TryLockWithTimeout(resource, token, defaultTimeout)
	return lock, ok, err
}

// TryLockWithTimeout makes one non-blocking attempt to acquire the lock on
// resource, holding it for timeout seconds.
//
// The optional domain selects the Redis instance. Only the first value is
// used, and "default" applies when none is given.
//
// It returns:
//   - a nil lock and the error when the Redis call fails;
//   - a non-nil lock with ok == false when the resource is already locked;
//   - a non-nil lock with ok == true when the lock was acquired.
func TryLockWithTimeout(resource string, token string, timeout int, domain ...string) (lock *Lock, ok bool, err error) {
	target := "default"
	if len(domain) != 0 {
		target = domain[0]
	}
	candidate := &Lock{resource: resource, token: token, timeout: timeout, domain: target}
	if ok, err = candidate.tryLock(); err != nil {
		return nil, ok, err
	}
	return candidate, ok, nil
}

// do sends commandName with args to the Redis instance selected by
// lock.domain.
//
// args[0] is treated as the key and replaced in place by its associate()
// mapping before the call. With no arguments it returns an error without
// contacting Redis.
func (lock *Lock) do(commandName string, args ...interface{}) (reply interface{}, err error) {
	if len(args) == 0 {
		return nil, errors.New("missing required arguments")
	}
	mapped := lock.associate(args[0])
	args[0] = mapped
	client := Redis(lock.domain)
	return client.Do(commandName, args...)
}

// associate maps a raw key argument to the key string sent to Redis. do()
// passes its first argument through it.
//
// String and []byte keys are returned as-is, as text. Any other value is
// rendered with fmt.Sprint. A "%s" verb is deliberately avoided because it
// turns non-string values into malformed keys such as "%!s(int=5)".
func (lock *Lock) associate(originKey interface{}) string {
	switch k := originKey.(type) {
	case string:
		return k
	case []byte:
		return string(k)
	default:
		return fmt.Sprint(k)
	}
}
